package com.qcrud.core.sql;

import com.qcrud.core.SqlData;
import com.qcrud.core.SqlManager;
import com.qcrud.core.parsing.ColumnInfo;
import com.qcrud.core.parsing.TableInfo;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * Standard SQL implementation of {@link CrudSql}. Subclasses override the parts whose SQL is
 * not standard (dialect-specific).
 *
 * <p>Every public method binds its positional parameters into {@link #sqlData} (1-based) and
 * returns the SQL text containing the matching {@code ?} placeholders.
 */
public class StandardSql implements CrudSql {
    /**
     * Cache of generated SQL (or SQL prefixes), keyed by {@code SqlData#getStatementKey()}.
     * The map is static, so it is shared by all instances and all subclasses.
     * NOTE(review): this assumes the statement key already distinguishes entity, operation and
     * dialect; otherwise two subclasses could read each other's cached SQL — confirm.
     */
    protected static final Map<String, String> SQL_MAP = new ConcurrentHashMap<>();
    protected SqlData sqlData;
    protected TableInfo tableInfo;

    /**
     * @param sqlData statement context; its entity class selects the table metadata
     */
    public StandardSql(SqlData sqlData) {
        this.sqlData = sqlData;
        this.tableInfo = SqlManager.getTableInfo(sqlData.getEntityClass());
    }

    /**
     * Builds an INSERT for a single entity. Every column value is bound in
     * {@code getColumnInfos()} order, starting at position 1.
     */
    @Override
    public String insert(Object entity) {
        bindColumns(entity, 1);
        return SQL_MAP.computeIfAbsent(sqlData.getStatementKey(),
            k -> this.getInsertSql(tableInfo, tableInfo::columnJoining));
    }

    /**
     * Builds {@code INSERT INTO table (c1,c2,...) VALUES } followed by the supplier's text.
     *
     * @param tableInfo table metadata
     * @param supplier  text appended after {@code VALUES }; may be {@code null}, in which case
     *                  the caller appends the values part itself
     * @return the INSERT statement (or prefix when {@code supplier} is {@code null})
     */
    protected String getInsertSql(TableInfo tableInfo, Supplier<String> supplier) {
        StringBuilder sql = new StringBuilder("INSERT INTO ");
        sql.append(tableInfo.getName()).append(" (");
        sql.append(tableInfo.getColumnInfos().stream().map(ColumnInfo::getName).collect(Collectors.joining(",")));
        sql.append(") VALUES ");
        if (null != supplier) {
            sql.append(supplier.get());
        }
        return sql.toString();
    }

    /**
     * Builds a multi-row INSERT. Column values are bound entity by entity, continuing the
     * position numbering across entities.
     *
     * @throws IllegalArgumentException if {@code entityList} is empty (would yield invalid SQL)
     */
    @Override
    public String insertBatch(List<Object> entityList) {
        requireNotEmpty(entityList, "entityList");
        int index = 1;
        for (Object entity : entityList) {
            index = bindColumns(entity, index);
        }
        String head = SQL_MAP.computeIfAbsent(sqlData.getStatementKey(), k -> this.getInsertSql(tableInfo, null));
        // One values group per entity; the group text is the same for every row.
        String row = tableInfo.columnJoining();
        return head + entityList.stream().map(e -> row).collect(Collectors.joining(","));
    }

    /**
     * Builds an UPDATE of all columns by id. All column values are bound first, then the id.
     */
    @Override
    public String updateById(Object entity) {
        int index = bindColumns(entity, 1);
        ColumnInfo idInfo = tableInfo.getIdInfo();
        sqlData.addPosition(index, idInfo.fieldGet(entity));
        return this.getUpdateSql(idInfo, tableInfo.getColumnInfos());
    }

    /**
     * Builds {@code UPDATE table SET c1=?,c2=?,... WHERE id=?} for the given columns.
     *
     * @param idInfo      id column used in the WHERE clause
     * @param columnInfos columns to assign, in placeholder order
     */
    protected String getUpdateSql(ColumnInfo idInfo, List<ColumnInfo> columnInfos) {
        StringBuilder sql = new StringBuilder();
        sql.append("UPDATE ").append(tableInfo.getName()).append(" SET ");
        sql.append(columnInfos.stream().map(t -> t.getName() + "=?").collect(Collectors.joining(",")));
        sql.append(" WHERE ").append(idInfo.getName()).append("=?");
        return sql.toString();
    }

    /**
     * Builds an UPDATE that only assigns the columns whose value on {@code entity} is non-null.
     * Those values are bound first, then the id. The result is not cached because the column
     * set depends on the entity.
     *
     * @throws IllegalArgumentException if every column value is null (would yield
     *                                  {@code SET  WHERE}, which is invalid SQL)
     */
    @Override
    public String updateSelectiveById(Object entity) {
        int index = 1;
        List<ColumnInfo> assigned = new ArrayList<>();
        for (ColumnInfo columnInfo : tableInfo.getColumnInfos()) {
            Object value = columnInfo.fieldGet(entity);
            if (null != value) {
                sqlData.addPosition(index++, value);
                assigned.add(columnInfo);
            }
        }
        if (assigned.isEmpty()) {
            throw new IllegalArgumentException("updateSelectiveById requires at least one non-null column value");
        }
        ColumnInfo idInfo = tableInfo.getIdInfo();
        sqlData.addPosition(index, idInfo.fieldGet(entity));
        return this.getUpdateSql(idInfo, assigned);
    }

    /** Builds {@code DELETE FROM table WHERE id=?} and binds {@code id} at position 1. */
    @Override
    public String deleteById(Object id) {
        sqlData.addPosition(1, id);
        return SQL_MAP.computeIfAbsent(sqlData.getStatementKey(), k -> this.getDeleteSqlWhere(() -> "=?"));
    }

    /**
     * Builds {@code DELETE FROM table WHERE id} followed by the supplier's text.
     *
     * @param supplier condition suffix (e.g. {@code "=?"}); may be {@code null} to return the
     *                 bare prefix
     */
    protected String getDeleteSqlWhere(Supplier<String> supplier) {
        ColumnInfo idInfo = tableInfo.getIdInfo();
        StringBuilder sql = new StringBuilder("DELETE FROM ");
        sql.append(tableInfo.getName()).append(" WHERE ");
        sql.append(idInfo.getName());
        if (null != supplier) {
            sql.append(supplier.get());
        }
        return sql.toString();
    }

    /**
     * Builds {@code DELETE ... WHERE id IN (?,...)} with one placeholder per id.
     *
     * @throws IllegalArgumentException if {@code ids} is empty ({@code IN ()} is invalid SQL)
     */
    @Override
    public String deleteBatchByIds(List<Object> ids) {
        requireNotEmpty(ids, "ids");
        sqlData.addPosition(ids);
        return SQL_MAP.computeIfAbsent(sqlData.getStatementKey(), k -> this.getDeleteSqlWhere(null)) + in(" IN ", ids);
    }

    /** Builds {@code SELECT ... WHERE id=?} and binds {@code id} at position 1. */
    @Override
    public String selectById(Object id) {
        sqlData.addPosition(1, id);
        return SQL_MAP.computeIfAbsent(sqlData.getStatementKey(), k -> this.getSelectSqlWhere(() -> "=?"));
    }

    /**
     * Builds {@code SELECT id, c1, ... FROM table WHERE id} followed by the supplier's text.
     * Columns flagged by {@code isConvert()} are aliased to their property name.
     *
     * @param supplier condition suffix (e.g. {@code "=?"}); may be {@code null} to return the
     *                 bare prefix
     */
    protected String getSelectSqlWhere(Supplier<String> supplier) {
        ColumnInfo idInfo = tableInfo.getIdInfo();
        StringBuilder sql = new StringBuilder("SELECT ");
        appendSelectColumn(sql, idInfo);
        for (ColumnInfo columnInfo : tableInfo.getColumnInfos()) {
            sql.append(",");
            appendSelectColumn(sql, columnInfo);
        }
        sql.append(" FROM ").append(tableInfo.getName()).append(" WHERE ");
        sql.append(idInfo.getName());
        if (null != supplier) {
            sql.append(supplier.get());
        }
        return sql.toString();
    }

    /**
     * Builds {@code SELECT ... WHERE id IN (?,...)} with one placeholder per id.
     *
     * @throws IllegalArgumentException if {@code ids} is empty ({@code IN ()} is invalid SQL)
     */
    @Override
    public String selectBatchByIds(List<Object> ids) {
        requireNotEmpty(ids, "ids");
        sqlData.addPosition(ids);
        return SQL_MAP.computeIfAbsent(sqlData.getStatementKey(), k -> this.getSelectSqlWhere(null)) + in(" IN ", ids);
    }

    /**
     * Builds an IN clause with one {@code ?} placeholder per element, e.g.
     * {@code in(" IN ", [a, b])} returns {@code " IN (?,?)"}.
     *
     * @param op   operator text placed before the parenthesis (e.g. {@code " IN "})
     * @param objs values the placeholders stand for; only the size is used
     * @return {@code op + "(?,...)"}
     */
    public String in(String op, List<Object> objs) {
        return op + "(" + objs.stream().map(t -> "?").collect(Collectors.joining(",")) + ")";
    }

    /**
     * Binds every column value of {@code entity}, in {@code getColumnInfos()} order, starting at
     * position {@code index}.
     *
     * @return the next unused position
     */
    private int bindColumns(Object entity, int index) {
        for (ColumnInfo columnInfo : tableInfo.getColumnInfos()) {
            sqlData.addPosition(index++, columnInfo.fieldGet(entity));
        }
        return index;
    }

    /** Appends {@code name} or {@code name AS property} for a selected column. */
    private static void appendSelectColumn(StringBuilder sql, ColumnInfo columnInfo) {
        sql.append(columnInfo.getName());
        if (columnInfo.isConvert()) {
            sql.append(" AS ").append(columnInfo.getProperty());
        }
    }

    /** Rejects empty lists that would otherwise produce syntactically invalid SQL. */
    private static void requireNotEmpty(List<?> list, String name) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException(name + " must not be empty");
        }
    }
}
